## Short Intro
This repo explores the efficiency and accuracy of efficient linear layers in Transformer models.


### File Organization
```
Structured/src/
├── benchmark_acc/                          [training and evaluation entry points for different datasets]
│   └── refinedweb_experiment.py                [RefinedWeb entry point]
├── benchmark_eff                           [efficiency benchmarks]
│   ├── bench_kernel.py                         [kernel efficiency]
│   ├── bench_mlp.py                            [mlp efficiency]
│   ├── benchmark_model_infer.py                [decoding efficiency]
│   └── benchmark_model_train.py                [prefill efficiency]
├── configs                                 [Hydra configs]
│   ├── data                                    [unused; RefinedWeb is preprocessed in advance]
│   ├── method                                  [efficient linear layer variants]
│   ├── model                                   [GPT and LLaMA]
│   ├── optimization                            [optimization including scheduler, optimizer, self-guided training etc.]
│   └── refinedweb_config.yaml
├── data                                    [unused]
├── modules
│   ├── __init__.py
│   ├── op                                       [fast ops; common ones invoke existing implementations or are copied from Megatron]
│   ├── layer                                    [efficient linear layers that call functions in op/]
│   ├── mlp                                      [efficient MLPs that call functions in layer/]
│   └── model                                    [supports LayerNorm or RMSNorm, with or without bias, tied or untied embeddings, rotary or absolute position embeddings, GELU or SwiGLU]
├── optimization
│   ├── __init__.py
│   ├── scheduler.py                             [cosine with warmup]
│   └── trainer.py                               [basic training function including seed, checkpoint, and info]
└── utils
    └── refinedweb_llama.py                      [preprocess file]
```
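`optimization/scheduler.py` is described as a cosine schedule with warmup. For orientation, here is a minimal sketch of such a schedule, assuming a linear warmup followed by cosine decay to `min_lr`. The function name and its arguments are hypothetical and are not taken from the repo's scheduler.

```python
import math


def cosine_with_warmup(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay to `min_lr`.

    Hypothetical helper for illustration; not the repo's scheduler.py API.
    """
    if step < warmup_steps:
        # Linear ramp from max_lr / warmup_steps up to max_lr.
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Example: peak LR 3.0e-4 (as in the commands below), 100 warmup steps, 1100 steps total.
lrs = [cosine_with_warmup(s, 3.0e-4, warmup_steps=100, total_steps=1100) for s in range(1100)]
```

The warmup length and floor LR here are illustrative; the actual values come from the Hydra `optimization` config.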

### Data preprocessing
Run the preprocessing script in `utils/`:
```
python refinedweb_llama.py
```

RefinedWeb is large, so we shuffle, extract, and tokenize it in advance. The resulting token ids are stored in an `np.memmap`, so the dataset never has to be loaded into CPU memory all at once.
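The memmap idea can be sketched as follows. This assumes a hypothetical layout (one flat file of `uint16` token ids, which fits the 32,000-token LLaMA vocabulary) and hypothetical helper names; the real file format written by `refinedweb_llama.py` may differ. Only the windows that are sliced out are read from disk.

```python
import os
import tempfile

import numpy as np


def write_token_memmap(path, token_ids, dtype=np.uint16):
    """Write a flat sequence of token ids to a raw memory-mapped file (assumed layout)."""
    arr = np.memmap(path, dtype=dtype, mode="w+", shape=(len(token_ids),))
    arr[:] = np.asarray(token_ids, dtype=dtype)
    arr.flush()
    return len(token_ids)


def sample_batch(path, num_tokens, seq_len, batch_size, rng, dtype=np.uint16):
    """Read random (input, target) windows without loading the whole file into RAM."""
    data = np.memmap(path, dtype=dtype, mode="r", shape=(num_tokens,))
    # Largest valid start is num_tokens - seq_len - 1 so the shifted target fits.
    starts = rng.integers(0, num_tokens - seq_len, size=batch_size)
    x = np.stack([data[s:s + seq_len] for s in starts]).astype(np.int64)
    y = np.stack([data[s + 1:s + 1 + seq_len] for s in starts]).astype(np.int64)
    return x, y


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.bin")  # hypothetical file name
    n = write_token_memmap(path, range(1000))
    x, y = sample_batch(path, n, seq_len=8, batch_size=4, rng=np.random.default_rng(0))
```

Targets are the inputs shifted by one token, the usual setup for next-token prediction.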


### Examples
```
# GPT-2 with the standard linear layer
torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=linear

# low-rank linear (set m_token to the training token budget first)
torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowranklinear method.kwargs.rank=256 optimization.max_tokens=${m_token} optimization.optimizer.kwargs.lr=3.0e-4 data.train.train_batch=16 data.test.test_batch=16

# self-guided training
torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2m method=lowranklinear method.kwargs.rank=256 optimization.max_tokens=${m_token} optimization.optimizer.kwargs.lr=3.0e-4 optimization.training.kwargs.reduce_flop=true data.train.train_batch=16 data.test.test_batch=16

# FLOP-matched setting (fixed FLOP budget)
torchrun --nnodes=1 --nproc_per_node=1 refinedweb_experiment.py model=gpt2 method=lowranklinear method.kwargs.rank=192 optimization.max_tokens=2500000000 optimization.training.kwargs.mode=fixedflop optimization.training.kwargs.max_step_ratio=0.25 data.train.train_batch=32 data.test.test_batch=32
```
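The `lowranklinear` method takes a `method.kwargs.rank` argument. A common construction, assumed here, factors the `d_in × d_out` weight into two rank-`r` matrices, so the cost per token drops from `d_in·d_out` to `r·(d_in + d_out)` multiply-accumulates. The NumPy sketch below uses hypothetical names and is not the implementation in `modules/layer`. The sizes 768 and 3072 assume GPT-2 small's hidden size with a 4× MLP expansion.

```python
import numpy as np


def dense_macs(d_in, d_out):
    """Multiply-accumulates per token for a dense d_in -> d_out projection."""
    return d_in * d_out


def lowrank_macs(d_in, d_out, rank):
    """Multiply-accumulates per token for x @ U @ V with U: d_in x r, V: r x d_out."""
    return rank * (d_in + d_out)


class LowRankLinear:
    """Hypothetical low-rank linear layer: W is replaced by the product U @ V."""

    def __init__(self, d_in, d_out, rank, rng):
        self.u = rng.standard_normal((d_in, rank)) / np.sqrt(d_in)
        self.v = rng.standard_normal((rank, d_out)) / np.sqrt(rank)
        self.bias = np.zeros(d_out)

    def __call__(self, x):
        # Project down to the rank-r bottleneck first, then back up.
        return (x @ self.u) @ self.v + self.bias


rng = np.random.default_rng(0)
layer = LowRankLinear(768, 3072, rank=192, rng=rng)
y = layer(rng.standard_normal((4, 768)))
ratio = lowrank_macs(768, 3072, 192) / dense_macs(768, 3072)
print(y.shape, ratio)  # (4, 3072) 0.3125
```

Under these assumed sizes, rank 192 costs 31.25% of the dense projection, which is why a FLOP-matched comparison needs its own token or step budget rather than reusing the dense setting.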
